{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 提交PyTorch分布式作业\n",
    "\n",
    "\n",
    "PAI支持用户提交分布式PyTorch训练作业，本文将介绍如何使用PAI Python SDK，以[PyTorch DDP(DistributedDataParallel)](https://pytorch.org/docs/stable/notes/ddp.html)模式提交分布式PyTorch训练作业。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## 费用说明\n",
    "\n",
    "本示例将会使用以下云产品，并产生相应的费用账单：\n",
    "\n",
    "- PAI-DLC：运行训练任务，详细计费说明请参考[PAI-DLC计费说明](https://help.aliyun.com/zh/pai/product-overview/billing-of-dlc)\n",
    "- OSS：存储训练任务输出的模型、训练代码等，详细计费说明请参考[OSS计费概述](https://help.aliyun.com/zh/oss/product-overview/billing-overview)\n",
    "\n",
    "\n",
    "> 通过参与云产品免费试用，使用**指定资源机型**提交训练作业或是部署推理服务，可以免费试用PAI产品，具体请参考[PAI免费试用](https://help.aliyun.com/zh/pai/product-overview/free-quota-for-new-users)。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## 安装和配置SDK\n",
    "\n",
    "我们需要首先安装PAI Python SDK以运行本示例。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "skip-execution"
    ]
   },
   "outputs": [],
   "source": [
    "!python -m pip install --upgrade pai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "!python -m pip install pygments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "SDK需要配置访问阿里云服务需要的AccessKey，以及当前使用的工作空间和OSS Bucket。在PAI SDK安装之后，通过在 **命令行终端** 中执行以下命令，按照引导配置密钥、工作空间等信息。\n",
    "\n",
    "\n",
    "```shell\n",
    "\n",
    "# 以下命令，请在 命令行终端 中执行.\n",
    "\n",
    "python -m pai.toolkit.config\n",
    "\n",
    "```\n",
    "\n",
    "我们可以通过以下代码验证配置是否已生效。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pai\n",
    "from pai.session import get_default_session\n",
    "\n",
    "print(pai.__version__)\n",
    "\n",
    "sess = get_default_session()\n",
    "\n",
    "# 获取配置的工作空间信息\n",
    "assert sess.workspace_name is not None\n",
    "print(sess.workspace_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyToch 分布式作业介绍\n",
    "\n",
    "[PyTorch DDP（Distributed Data Parallel）](https://pytorch.org/docs/stable/notes/ddp.html)是PyTorch提供的分布式数据并行训练功能，支持模型在多台机器上进行并行训练，从而提高训练效率。\n",
    "\n",
    "PyTorch DDP基于多进程的方案实现，支持单机多卡模式和多机多卡模式。在单机多卡模式下，用户可以使用同一台机器下的多个GPU来加速模型的训练。在多机多卡模式下，可以将计算任务分配到多台机器上进行并行计算，加速训练速度。对于DDP的详细介绍，可以参考PyTorch的[官方文档链接](https://pytorch.org/docs/stable/notes/ddp.html)。\n",
    "\n",
    "\n",
    "![PyTorch DDP](./resource/ddp.png)\n",
    "\n",
    "> PyTorch提供的`DataParallel`和`DistributedDataParallel`模块都支持数据并行训练，[PyTorch官方](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#comparison-between-dataparallel-and-distributeddataparallel)推荐不论是单机多卡还是多机多卡，都使用`DistributedDataParallel`模块进行训练。\n",
    "\n",
    "### 代码适配DDP改造\n",
    "\n",
    "使用PyTorch DDP进行分布式训练需要对原先的PyTorch训练代码（使用单机单卡）进行的修改，具体可以参考[PyTorch官方文档](https://pytorch.org/tutorials/beginner/dist_overview.html#torch-nn-parallel-distributeddataparallel)。\n",
    "\n",
    "主要包括：\n",
    "\n",
    "- 初始化分布式训练配置:\n",
    "\n",
    "需要在训练迭代开始之前完成训练环境初始化。\n",
    "\n",
    "```python\n",
    "\n",
    "from torch.distributed import init_process_group, destroy_process_group\n",
    "\n",
    "def ddp_setup()\n",
    "    init_process_group(backend=\"nccl\")\n",
    "\n",
    "```\n",
    "\n",
    "初始化需要指定机器之间的通讯方式，当使用GPU进行训练时，通常使用`nccl`作为通讯后端，而使用CPU训练时，使用`gloo`，详细的介绍可以参考PyTorch文档: [Which Backend To Use?](https://pytorch.org/docs/stable/distributed.html#which-backend-to-use)\n",
    "\n",
    "- 使用DDP封装模型：\n",
    "\n",
    "```python\n",
    "\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP\n",
    "\n",
    "# model是原始单机单卡训练的PyTorch模型\n",
    "model = DDP(model)\n",
    "\n",
    "```\n",
    "\n",
    "\n",
    "- 修改DataLoader的采样方式：\n",
    "\n",
    "当使用DDP进行数据并行训练，不同的worker进程需要读取不同的数据分片进行训练。当不同机器上通过共享存储的方式使用同一份数据集时，可以使用`torch.utils.data.distributed.DistributedSampler`来对数据进行采样，从而保证不同的worker进程读取不同的数据分片。\n",
    "\n",
    "```python\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "\n",
    "train_sampler = DistributedSampler(\n",
    "\ttrain_dataset,\n",
    "\tshuffle=True)\n",
    "\n",
    "trainloader = DataLoader(\n",
    "\ttrain_dataset,\n",
    "\tbatch_size=args.per_device_train_batch_size,\n",
    "\tsampler=train_sampler,\n",
    "\tnum_workers=2,\n",
    "\tdrop_last=True)\n",
    "\n",
    "```\n",
    "\n",
    "\n"
   ]
  },
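  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sharding behavior of `DistributedSampler` can be illustrated without PyTorch: each worker keeps every `world_size`-th index starting from its rank, so the workers together cover the dataset exactly once. The following plain-Python sketch shows only this core idea (the real sampler also shuffles per epoch and pads the last shard):\n",
    "\n",
    "```python\n",
    "def shard_indices(dataset_size, world_size, rank):\n",
    "    # Each rank takes indices rank, rank + world_size, rank + 2 * world_size, ...\n",
    "    return list(range(rank, dataset_size, world_size))\n",
    "\n",
    "# With 8 samples and 4 workers, each worker sees a disjoint shard:\n",
    "shards = [shard_indices(8, 4, r) for r in range(4)]\n",
    "print(shards)  # [[0, 4], [1, 5], [2, 6], [3, 7]]\n",
    "```\n"
   ]
  },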
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### PAI支持PyTorch DDP分布式训练\n",
    "\n",
    "PAI原生支持的PyTorch的分布式训练，当用户提交训练作业，指定作业类型为PyTorch训练作业时(`job_type=\"PyTorchJob\"`)，PAI的训练服务会在机器节点上设置环境变量，包含作业机器数量，机器RANK，机器之间的通讯地址等信息。\n",
    "\n",
    "| 环境变量名 | \t描述 |\n",
    "|:----------|:---------|\n",
    "|MASTER_ADDR | Master机器节点的服务地址 |\n",
    "|MASTER_PORT | Master机器节点端口号，如：23456 |\n",
    "|WORLD_SIZE\t | 分布式作业的**机器节点总数**，例如提交的训练作业申请了4台机器，则WORLD_ISZE=4 |\n",
    "|RANK\t| **机器节点的RANK**，例如启动了一个4个节点的作业，则各个机器节点的RANK分别为0,1,2,3 |\n",
    "\n",
    "\n",
    "`PyTorch`提供了分布式训练启动工具，`torchrun`(PyTorch 1.1.0及以上版本) 和 `torch.distributed.launch`(PyTorch 1.1.0版本以下)，支持训练作业的拉起。配合以上PAI预置的环境变量，我们可以便利得启动分布式训练作业。\n",
    "\n",
    "\n",
    "\n",
    "使用`torch.distributed.launch`拉起训练作业示例：\n",
    "\n",
    "```shell\n",
    "\n",
    "# for PyTorch<1.1.0\n",
    "\n",
    "python -m torch.distributed.launch \\\n",
    "--nproc_per_node=<NumberOrProcessPerNode> \\\n",
    "--master_addr=$MASTER_ADDR \\\n",
    "--master_port=$MASTER_PORT \\\n",
    "--nnodes=$WORLD_SIZE \\\n",
    "--node_rank=$RANK \\\n",
    "<YourTrainingScript> training_arguments...\n",
    "\n",
    "```\n",
    "\n",
    "使用`torchrun`拉起训练作业示例：\n",
    "\n",
    "```shell\n",
    "\n",
    "# for PyTorch>=1.1.0\n",
    "torchrun \\\n",
    "--nproc_per_node=<NumberOrProcessPerNode> \\\n",
    "--master_addr=$MASTER_ADDR \\\n",
    "--master_port=$MASTER_PORT \\\n",
    "--nnodes=$WORLD_SIZE \\\n",
    "--node_rank=$RANK \\\n",
    "<YourTrainingScript> training_arguments...\n",
    "\n",
    "```\n",
    "\n",
    "用户需要修改`<NumberOfProcessPerNode`为每一个机器节点需要启动的进程数，通常设置为机器节点的GPU数量。\n",
    "\n",
    "\n",
    "> 以上的作业启动命令，同样适用于单机多卡的训练作业启动。\n",
    "\n"
   ]
  },
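  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of how these variables feed into the launcher, the `torchrun` command can be assembled from the environment. The defaults below are illustrative stand-ins for the values PAI would inject; the variable names match the table above:\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "def build_torchrun_command(script, nproc_per_node):\n",
    "    # Fall back to single-node defaults when the PAI-injected variables are absent.\n",
    "    master_addr = os.environ.get(\"MASTER_ADDR\", \"127.0.0.1\")\n",
    "    master_port = os.environ.get(\"MASTER_PORT\", \"23456\")\n",
    "    nnodes = os.environ.get(\"WORLD_SIZE\", \"1\")\n",
    "    node_rank = os.environ.get(\"RANK\", \"0\")\n",
    "    return (\n",
    "        f\"torchrun --master_addr={master_addr} --master_port={master_port} \"\n",
    "        f\"--nnodes={nnodes} --node_rank={node_rank} \"\n",
    "        f\"--nproc_per_node={nproc_per_node} {script}\"\n",
    "    )\n",
    "\n",
    "print(build_torchrun_command(\"train_multinode.py\", 2))\n",
    "```\n"
   ]
  },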
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 提交训练作业\n",
    "\n",
    "PAI Python SDK提供了Estimator的接口，用于提交训练作业，以下示例中，我们将通过Estimator提交一个PyTorch分布式训练作业。\n",
    "\n",
    "\n",
    "- 准备训练代码\n",
    "\n",
    "PyTorch提供了多机多卡的[训练代码示例](https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multinode.py)，在修改了模型和checkpoints保存路径后，我们既可以将其用于提交到PAI进行训练。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[38;2;61;123;123;03m#  Copyright 2023 Alibaba, Inc. or its affiliates.\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  Licensed under the Apache License, Version 2.0 (the \"License\");\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  you may not use this file except in compliance with the License.\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  You may obtain a copy of the License at\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#       https://www.apache.org/licenses/LICENSE-2.0\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  Unless required by applicable law or agreed to in writing, software\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  distributed under the License is distributed on an \"AS IS\" BASIS,\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  See the License for the specific language governing permissions and\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  limitations under the License.\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  Licensed under the Apache License, Version 2.0 (the \"License\");\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  you may not use this file except in compliance with the License.\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  You may obtain a copy of the License at\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#       https://www.apache.org/licenses/LICENSE-2.0\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  Unless required by applicable law or agreed to in writing, software\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  distributed under the License is distributed on an \"AS IS\" BASIS,\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  See the License for the specific language governing permissions and\u001b[39;00m\n",
      "\u001b[38;2;61;123;123;03m#  limitations under the License.\u001b[39;00m\n",
      "\n",
      "\u001b[38;2;61;123;123;03m# source: https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multinode.py\u001b[39;00m\n",
      "\u001b[38;2;0;128;0;01mimport\u001b[39;00m \u001b[38;2;0;0;255;01mos\u001b[39;00m\n",
      "\n",
      "\u001b[38;2;0;128;0;01mimport\u001b[39;00m \u001b[38;2;0;0;255;01mtorch\u001b[39;00m\n",
      "\u001b[38;2;0;128;0;01mimport\u001b[39;00m \u001b[38;2;0;0;255;01mtorch\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mnn\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mfunctional\u001b[39;00m \u001b[38;2;0;128;0;01mas\u001b[39;00m \u001b[38;2;0;0;255;01mF\u001b[39;00m\n",
      "\u001b[38;2;0;128;0;01mfrom\u001b[39;00m \u001b[38;2;0;0;255;01mtorch\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mutils\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mdata\u001b[39;00m \u001b[38;2;0;128;0;01mimport\u001b[39;00m Dataset, DataLoader\n",
      "\n",
      "\u001b[38;2;0;128;0;01mfrom\u001b[39;00m \u001b[38;2;0;0;255;01mtorch\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mutils\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mdata\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mdistributed\u001b[39;00m \u001b[38;2;0;128;0;01mimport\u001b[39;00m DistributedSampler\n",
      "\u001b[38;2;0;128;0;01mfrom\u001b[39;00m \u001b[38;2;0;0;255;01mtorch\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mutils\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mdata\u001b[39;00m \u001b[38;2;0;128;0;01mimport\u001b[39;00m Dataset\n",
      "\u001b[38;2;0;128;0;01mfrom\u001b[39;00m \u001b[38;2;0;0;255;01mtorch\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mnn\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mparallel\u001b[39;00m \u001b[38;2;0;128;0;01mimport\u001b[39;00m DistributedDataParallel \u001b[38;2;0;128;0;01mas\u001b[39;00m DDP\n",
      "\u001b[38;2;0;128;0;01mfrom\u001b[39;00m \u001b[38;2;0;0;255;01mtorch\u001b[39;00m\u001b[38;2;0;0;255;01m.\u001b[39;00m\u001b[38;2;0;0;255;01mdistributed\u001b[39;00m \u001b[38;2;0;128;0;01mimport\u001b[39;00m init_process_group, destroy_process_group\n",
      "\n",
      "\n",
      "\u001b[38;2;0;128;0;01mclass\u001b[39;00m \u001b[38;2;0;0;255;01mMyTrainDataset\u001b[39;00m(Dataset):\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255m__init__\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m, size):\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39msize \u001b[38;2;102;102;102m=\u001b[39m size\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mdata \u001b[38;2;102;102;102m=\u001b[39m [(torch\u001b[38;2;102;102;102m.\u001b[39mrand(\u001b[38;2;102;102;102m20\u001b[39m), torch\u001b[38;2;102;102;102m.\u001b[39mrand(\u001b[38;2;102;102;102m1\u001b[39m)) \u001b[38;2;0;128;0;01mfor\u001b[39;00m _ \u001b[38;2;170;34;255;01min\u001b[39;00m \u001b[38;2;0;128;0mrange\u001b[39m(size)]\n",
      "\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255m__len__\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m):\n",
      "        \u001b[38;2;0;128;0;01mreturn\u001b[39;00m \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39msize\n",
      "\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255m__getitem__\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m, index):\n",
      "        \u001b[38;2;0;128;0;01mreturn\u001b[39;00m \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mdata[index]\n",
      "\n",
      "\n",
      "\u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255mddp_setup\u001b[39m():\n",
      "    init_process_group(backend\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mnccl\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m)\n",
      "    torch\u001b[38;2;102;102;102m.\u001b[39mcuda\u001b[38;2;102;102;102m.\u001b[39mset_device(\u001b[38;2;0;128;0mint\u001b[39m(os\u001b[38;2;102;102;102m.\u001b[39menviron[\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mLOCAL_RANK\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m]))\n",
      "\n",
      "\n",
      "\u001b[38;2;0;128;0;01mclass\u001b[39;00m \u001b[38;2;0;0;255;01mTrainer\u001b[39;00m:\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255m__init__\u001b[39m(\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m,\n",
      "        model: torch\u001b[38;2;102;102;102m.\u001b[39mnn\u001b[38;2;102;102;102m.\u001b[39mModule,\n",
      "        train_data: DataLoader,\n",
      "        optimizer: torch\u001b[38;2;102;102;102m.\u001b[39moptim\u001b[38;2;102;102;102m.\u001b[39mOptimizer,\n",
      "        save_every: \u001b[38;2;0;128;0mint\u001b[39m,\n",
      "        output_model_path: \u001b[38;2;0;128;0mstr\u001b[39m,\n",
      "        checkpoint_path: \u001b[38;2;0;128;0mstr\u001b[39m,\n",
      "    ) \u001b[38;2;102;102;102m-\u001b[39m\u001b[38;2;102;102;102m>\u001b[39m \u001b[38;2;0;128;0;01mNone\u001b[39;00m:\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mlocal_rank \u001b[38;2;102;102;102m=\u001b[39m \u001b[38;2;0;128;0mint\u001b[39m(os\u001b[38;2;102;102;102m.\u001b[39menviron[\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mLOCAL_RANK\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m])\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mglobal_rank \u001b[38;2;102;102;102m=\u001b[39m \u001b[38;2;0;128;0mint\u001b[39m(os\u001b[38;2;102;102;102m.\u001b[39menviron[\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mRANK\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m])\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mmodel \u001b[38;2;102;102;102m=\u001b[39m model\u001b[38;2;102;102;102m.\u001b[39mto(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mlocal_rank)\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mtrain_data \u001b[38;2;102;102;102m=\u001b[39m train_data\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39moptimizer \u001b[38;2;102;102;102m=\u001b[39m optimizer\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39msave_every \u001b[38;2;102;102;102m=\u001b[39m save_every\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mepochs_run \u001b[38;2;102;102;102m=\u001b[39m \u001b[38;2;102;102;102m0\u001b[39m\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39moutput_model_path \u001b[38;2;102;102;102m=\u001b[39m output_model_path\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mcheckpoint_path \u001b[38;2;102;102;102m=\u001b[39m checkpoint_path\n",
      "        \u001b[38;2;0;128;0;01mif\u001b[39;00m os\u001b[38;2;102;102;102m.\u001b[39mpath\u001b[38;2;102;102;102m.\u001b[39mexists(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mget_snapshot_path()):\n",
      "            \u001b[38;2;0;128;0mprint\u001b[39m(\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mLoading snapshot\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m)\n",
      "            \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39m_load_snapshot(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mget_snapshot_path())\n",
      "\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mmodel \u001b[38;2;102;102;102m=\u001b[39m DDP(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mmodel, device_ids\u001b[38;2;102;102;102m=\u001b[39m[\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mlocal_rank])\n",
      "\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255mget_snapshot_path\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m):\n",
      "        \u001b[38;2;0;128;0;01mreturn\u001b[39;00m os\u001b[38;2;102;102;102m.\u001b[39mpath\u001b[38;2;102;102;102m.\u001b[39mjoin(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mcheckpoint_path, \u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mmodel.pt\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m)\n",
      "\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255m_load_snapshot\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m, snapshot_path):\n",
      "        loc \u001b[38;2;102;102;102m=\u001b[39m \u001b[38;2;186;33;33mf\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mcuda:\u001b[39m\u001b[38;2;164;90;119;01m{\u001b[39;00m\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mlocal_rank\u001b[38;2;164;90;119;01m}\u001b[39;00m\u001b[38;2;186;33;33m\"\u001b[39m\n",
      "        snapshot \u001b[38;2;102;102;102m=\u001b[39m torch\u001b[38;2;102;102;102m.\u001b[39mload(snapshot_path, map_location\u001b[38;2;102;102;102m=\u001b[39mloc)\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mmodel\u001b[38;2;102;102;102m.\u001b[39mload_state_dict(snapshot[\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mMODEL_STATE\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m])\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mepochs_run \u001b[38;2;102;102;102m=\u001b[39m snapshot[\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mEPOCHS_RUN\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m]\n",
      "        \u001b[38;2;0;128;0mprint\u001b[39m(\u001b[38;2;186;33;33mf\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mResuming training from snapshot at Epoch \u001b[39m\u001b[38;2;164;90;119;01m{\u001b[39;00m\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mepochs_run\u001b[38;2;164;90;119;01m}\u001b[39;00m\u001b[38;2;186;33;33m\"\u001b[39m)\n",
      "\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255m_run_batch\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m, source, targets):\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39moptimizer\u001b[38;2;102;102;102m.\u001b[39mzero_grad()\n",
      "        output \u001b[38;2;102;102;102m=\u001b[39m \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mmodel(source)\n",
      "        loss \u001b[38;2;102;102;102m=\u001b[39m F\u001b[38;2;102;102;102m.\u001b[39mcross_entropy(output, targets)\n",
      "        loss\u001b[38;2;102;102;102m.\u001b[39mbackward()\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39moptimizer\u001b[38;2;102;102;102m.\u001b[39mstep()\n",
      "\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255m_run_epoch\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m, epoch):\n",
      "        b_sz \u001b[38;2;102;102;102m=\u001b[39m \u001b[38;2;0;128;0mlen\u001b[39m(\u001b[38;2;0;128;0mnext\u001b[39m(\u001b[38;2;0;128;0miter\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mtrain_data))[\u001b[38;2;102;102;102m0\u001b[39m])\n",
      "        \u001b[38;2;0;128;0mprint\u001b[39m(\n",
      "            \u001b[38;2;186;33;33mf\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33m[GPU-\u001b[39m\u001b[38;2;164;90;119;01m{\u001b[39;00m\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mglobal_rank\u001b[38;2;164;90;119;01m}\u001b[39;00m\u001b[38;2;186;33;33m] Epoch \u001b[39m\u001b[38;2;164;90;119;01m{\u001b[39;00mepoch\u001b[38;2;164;90;119;01m}\u001b[39;00m\u001b[38;2;186;33;33m | Batchsize: \u001b[39m\u001b[38;2;164;90;119;01m{\u001b[39;00mb_sz\u001b[38;2;164;90;119;01m}\u001b[39;00m\u001b[38;2;186;33;33m | Steps: \u001b[39m\u001b[38;2;164;90;119;01m{\u001b[39;00m\u001b[38;2;0;128;0mlen\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mtrain_data)\u001b[38;2;164;90;119;01m}\u001b[39;00m\u001b[38;2;186;33;33m\"\u001b[39m\n",
      "        )\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mtrain_data\u001b[38;2;102;102;102m.\u001b[39msampler\u001b[38;2;102;102;102m.\u001b[39mset_epoch(epoch)\n",
      "        \u001b[38;2;0;128;0;01mfor\u001b[39;00m source, targets \u001b[38;2;170;34;255;01min\u001b[39;00m \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mtrain_data:\n",
      "            source \u001b[38;2;102;102;102m=\u001b[39m source\u001b[38;2;102;102;102m.\u001b[39mto(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mlocal_rank)\n",
      "            targets \u001b[38;2;102;102;102m=\u001b[39m targets\u001b[38;2;102;102;102m.\u001b[39mto(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mlocal_rank)\n",
      "            \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39m_run_batch(source, targets)\n",
      "\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255m_save_snapshot\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m, epoch):\n",
      "        snapshot \u001b[38;2;102;102;102m=\u001b[39m {\n",
      "            \u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mMODEL_STATE\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m: \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mmodel\u001b[38;2;102;102;102m.\u001b[39mmodule\u001b[38;2;102;102;102m.\u001b[39mstate_dict(),\n",
      "            \u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mEPOCHS_RUN\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m: epoch,\n",
      "        }\n",
      "        torch\u001b[38;2;102;102;102m.\u001b[39msave(snapshot, \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mget_snapshot_path())\n",
      "        \u001b[38;2;0;128;0mprint\u001b[39m(\u001b[38;2;186;33;33mf\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mEpoch \u001b[39m\u001b[38;2;164;90;119;01m{\u001b[39;00mepoch\u001b[38;2;164;90;119;01m}\u001b[39;00m\u001b[38;2;186;33;33m | Training snapshot saved at \u001b[39m\u001b[38;2;164;90;119;01m{\u001b[39;00m\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39moutput_model_path\u001b[38;2;164;90;119;01m}\u001b[39;00m\u001b[38;2;186;33;33m\"\u001b[39m)\n",
      "\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255m_save_model\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m):\n",
      "        torch\u001b[38;2;102;102;102m.\u001b[39msave(\n",
      "            \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mmodel\u001b[38;2;102;102;102m.\u001b[39mstate_dict(), os\u001b[38;2;102;102;102m.\u001b[39mpath\u001b[38;2;102;102;102m.\u001b[39mjoin(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39moutput_model_path, \u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mmodel.pt\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m)\n",
      "        )\n",
      "\n",
      "    \u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255mtrain\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m, max_epochs: \u001b[38;2;0;128;0mint\u001b[39m):\n",
      "        \u001b[38;2;0;128;0;01mfor\u001b[39;00m epoch \u001b[38;2;170;34;255;01min\u001b[39;00m \u001b[38;2;0;128;0mrange\u001b[39m(\u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mepochs_run, max_epochs):\n",
      "            \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39m_run_epoch(epoch)\n",
      "            \u001b[38;2;0;128;0;01mif\u001b[39;00m \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39mglobal_rank \u001b[38;2;102;102;102m==\u001b[39m \u001b[38;2;102;102;102m0\u001b[39m \u001b[38;2;170;34;255;01mand\u001b[39;00m epoch \u001b[38;2;102;102;102m%\u001b[39m \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39msave_every \u001b[38;2;102;102;102m==\u001b[39m \u001b[38;2;102;102;102m0\u001b[39m:\n",
      "                \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39m_save_snapshot(epoch)\n",
      "        \u001b[38;2;61;123;123;03m# save model after training\u001b[39;00m\n",
      "        \u001b[38;2;0;128;0mself\u001b[39m\u001b[38;2;102;102;102m.\u001b[39m_save_model()\n",
      "\n",
      "\n",
      "\u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255mload_train_objs\u001b[39m():\n",
      "    train_set \u001b[38;2;102;102;102m=\u001b[39m MyTrainDataset(\u001b[38;2;102;102;102m2048\u001b[39m)  \u001b[38;2;61;123;123;03m# load your dataset\u001b[39;00m\n",
      "    model \u001b[38;2;102;102;102m=\u001b[39m torch\u001b[38;2;102;102;102m.\u001b[39mnn\u001b[38;2;102;102;102m.\u001b[39mLinear(\u001b[38;2;102;102;102m20\u001b[39m, \u001b[38;2;102;102;102m1\u001b[39m)  \u001b[38;2;61;123;123;03m# load your model\u001b[39;00m\n",
      "    optimizer \u001b[38;2;102;102;102m=\u001b[39m torch\u001b[38;2;102;102;102m.\u001b[39moptim\u001b[38;2;102;102;102m.\u001b[39mSGD(model\u001b[38;2;102;102;102m.\u001b[39mparameters(), lr\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;102;102;102m1e-3\u001b[39m)\n",
      "    \u001b[38;2;0;128;0;01mreturn\u001b[39;00m train_set, model, optimizer\n",
      "\n",
      "\n",
      "\u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255mprepare_dataloader\u001b[39m(dataset: Dataset, batch_size: \u001b[38;2;0;128;0mint\u001b[39m):\n",
      "    \u001b[38;2;0;128;0;01mreturn\u001b[39;00m DataLoader(\n",
      "        dataset,\n",
      "        batch_size\u001b[38;2;102;102;102m=\u001b[39mbatch_size,\n",
      "        pin_memory\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;0;128;0;01mTrue\u001b[39;00m,\n",
      "        shuffle\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;0;128;0;01mFalse\u001b[39;00m,\n",
      "        sampler\u001b[38;2;102;102;102m=\u001b[39mDistributedSampler(dataset),\n",
      "    )\n",
      "\n",
      "\n",
      "\u001b[38;2;0;128;0;01mdef\u001b[39;00m \u001b[38;2;0;0;255mmain\u001b[39m(\n",
      "    save_every: \u001b[38;2;0;128;0mint\u001b[39m,\n",
      "    total_epochs: \u001b[38;2;0;128;0mint\u001b[39m,\n",
      "    batch_size: \u001b[38;2;0;128;0mint\u001b[39m,\n",
      "    output_model_path: \u001b[38;2;0;128;0mstr\u001b[39m,\n",
      "    checkpoint_path: \u001b[38;2;0;128;0mstr\u001b[39m,\n",
      "):\n",
      "    ddp_setup()\n",
      "\n",
      "    dataset, model, optimizer \u001b[38;2;102;102;102m=\u001b[39m load_train_objs()\n",
      "    train_data \u001b[38;2;102;102;102m=\u001b[39m prepare_dataloader(dataset, batch_size)\n",
      "    trainer \u001b[38;2;102;102;102m=\u001b[39m Trainer(\n",
      "        model, train_data, optimizer, save_every, output_model_path, checkpoint_path\n",
      "    )\n",
      "    trainer\u001b[38;2;102;102;102m.\u001b[39mtrain(total_epochs)\n",
      "    destroy_process_group()\n",
      "\n",
      "\n",
      "\u001b[38;2;0;128;0;01mif\u001b[39;00m \u001b[38;2;25;23;124m__name__\u001b[39m \u001b[38;2;102;102;102m==\u001b[39m \u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33m__main__\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m:\n",
      "    \u001b[38;2;0;128;0;01mimport\u001b[39;00m \u001b[38;2;0;0;255;01margparse\u001b[39;00m\n",
      "\n",
      "    parser \u001b[38;2;102;102;102m=\u001b[39m argparse\u001b[38;2;102;102;102m.\u001b[39mArgumentParser(description\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33msimple distributed training job\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m)\n",
      "    parser\u001b[38;2;102;102;102m.\u001b[39madd_argument(\n",
      "        \u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33m--total_epochs\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m, \u001b[38;2;0;128;0mtype\u001b[39m\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;0;128;0mint\u001b[39m, help\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mTotal epochs to train the model\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\n",
      "    )\n",
      "    parser\u001b[38;2;102;102;102m.\u001b[39madd_argument(\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33m--save_every\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m, \u001b[38;2;0;128;0mtype\u001b[39m\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;0;128;0mint\u001b[39m, help\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mHow often to save a snapshot\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m)\n",
      "    parser\u001b[38;2;102;102;102m.\u001b[39madd_argument(\n",
      "        \u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33m--batch_size\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m,\n",
      "        default\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;102;102;102m32\u001b[39m,\n",
      "        \u001b[38;2;0;128;0mtype\u001b[39m\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;0;128;0mint\u001b[39m,\n",
      "        help\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mInput batch size on each device (default: 32)\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m,\n",
      "    )\n",
      "    \u001b[38;2;61;123;123;03m# 使用PAI训练服务设置的环境变量，表示模型保存路径\u001b[39;00m\n",
      "    parser\u001b[38;2;102;102;102m.\u001b[39madd_argument(\n",
      "        \u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33m--output_model_path\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m,\n",
      "        default\u001b[38;2;102;102;102m=\u001b[39mos\u001b[38;2;102;102;102m.\u001b[39menviron\u001b[38;2;102;102;102m.\u001b[39mget(\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mPAI_OUTPUT_MODEL\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m),\n",
      "        \u001b[38;2;0;128;0mtype\u001b[39m\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;0;128;0mstr\u001b[39m,\n",
      "        help\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mOutput model path\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m,\n",
      "    )\n",
      "    \u001b[38;2;61;123;123;03m# 使用PAI训练服务设置的环境变量，表示checkpoints保存路径\u001b[39;00m\n",
      "    parser\u001b[38;2;102;102;102m.\u001b[39madd_argument(\n",
      "        \u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33m--checkpoint_path\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m,\n",
      "        default\u001b[38;2;102;102;102m=\u001b[39mos\u001b[38;2;102;102;102m.\u001b[39menviron\u001b[38;2;102;102;102m.\u001b[39mget(\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mPAI_OUTPUT_CHECKPOINTS\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m),\n",
      "        \u001b[38;2;0;128;0mtype\u001b[39m\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;0;128;0mstr\u001b[39m,\n",
      "        help\u001b[38;2;102;102;102m=\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m\u001b[38;2;186;33;33mcheckpoints path\u001b[39m\u001b[38;2;186;33;33m\"\u001b[39m,\n",
      "    )\n",
      "    args \u001b[38;2;102;102;102m=\u001b[39m parser\u001b[38;2;102;102;102m.\u001b[39mparse_args()\n",
      "\n",
      "    main(\n",
      "        args\u001b[38;2;102;102;102m.\u001b[39msave_every,\n",
      "        args\u001b[38;2;102;102;102m.\u001b[39mtotal_epochs,\n",
      "        args\u001b[38;2;102;102;102m.\u001b[39mbatch_size,\n",
      "        args\u001b[38;2;102;102;102m.\u001b[39moutput_model_path,\n",
      "        args\u001b[38;2;102;102;102m.\u001b[39mcheckpoint_path,\n",
      "    )\n"
     ]
    }
   ],
   "source": [
    "# 通过以下代码查看准备提交的训练代码\n",
    "!pygmentize train_src/train_multinode.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 提交训练作业\n",
    "\n",
    "我们将使用PAI提供的PyTorch 1.12版本的GPU镜像完成多机多卡的作业训练。使用`estimator.fit`提交训练作业之后，SDK会打印作业的控制台链接，用户可以通过控制台查看作业状态，日志详情等信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.estimator import Estimator\n",
    "from pai.image import retrieve\n",
    "\n",
    "# 使用PAI提供的PyTorch 1.12 GPU镜像\n",
    "image_uri = retrieve(\n",
    "    \"pytorch\",\n",
    "    framework_version=\"1.12\",\n",
    "    accelerator_type=\"GPU\",\n",
    ").image_uri\n",
    "print(\"Training Image URI: \", image_uri)\n",
    "\n",
    "\n",
    "# 每一个机器实例的GPU数量，需要根据用户选择的机器型号(instance_type)进行修改\n",
    "gpu_count_per_instance = 2\n",
    "\n",
    "# 训练脚本使用torchrun命令启动\n",
    "command = f\"\"\"torchrun --master_addr=$MASTER_ADDR \\\n",
    "--master_port=$MASTER_PORT \\\n",
    "--nnodes=$WORLD_SIZE --node_rank=$RANK \\\n",
    "--nproc_per_node={gpu_count_per_instance} \\\n",
    "train_multinode.py --total_epochs 10 --save_every 5 \\\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "# 提交训练作业\n",
    "est = Estimator(\n",
    "    image_uri=image_uri,\n",
    "    source_dir=\"./train_src\",  # 训练代码所在目录\n",
    "    command=command,\n",
    "    job_type=\"PyTorchJob\",\n",
    "    instance_type=\"ecs.gn6i-c24g1.12xlarge\",  # 2 * NVIDIA T4 GPU\n",
    "    instance_count=2,  # 机器实例数量\n",
    "    base_job_name=\"pytorch-ddp\",\n",
    ")\n",
    "\n",
    "# fit方法提交训练作业，默认等待到作业执行完成\n",
    "est.fit()\n",
    "\n",
    "\n",
    "# 查看作业的输出模型\n",
    "\n",
    "est.model_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参考：\n",
    "\n",
    "- PyTorch Distributed Overview: https://pytorch.org/tutorials/beginner/dist_overview.html"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pai-dev-py38",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
